[OMNIML-2852] [2/n] Add Core Sparse Attention Infrastructure #527
Conversation
Codecov Report

```
@@            Coverage Diff             @@
##             main     #527      +/-   ##
==========================================
+ Coverage   74.64%   74.95%   +0.31%
==========================================
  Files         183      192       +9
  Lines       18542    18939     +397
==========================================
+ Hits        13840    14196     +356
- Misses       4702     4743      +41
```
Force-pushed 54bfe2c to 0ce1376
Force-pushed fc9d285 to 5d027e0
Hi @kaix-nv, could you split this change further? This PR has 3000+ lines of code changes and many file moves.
```python
# Create registry for sparse attention modules
SparseAttentionRegistry = _DMRegistryCls("SparseAttention", SparseAttentionModule)
```
Can we use a single registry for all sparsity algorithms and modes, and then a top-level mts.sparsify(model, mode=...), so that all algorithms (e.g. weight or attention sparsity) are invoked through one shared API instead of a separate API per algorithm?
This is good advice. I'll submit a follow-up PR later.
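For illustration, a minimal sketch of what that unified entry point could look like; all names here (`SPARSITY_MODES`, `register_mode`, the sparsifier classes) are hypothetical stand-ins, not modelopt's actual registry plumbing:

```python
import torch.nn as nn

# Hypothetical shared registry mapping mode names to sparsity algorithms.
SPARSITY_MODES: dict[str, type] = {}

def register_mode(name: str):
    """Register a sparsity algorithm class under a mode name."""
    def wrap(cls):
        SPARSITY_MODES[name] = cls
        return cls
    return wrap

@register_mode("weight_sparsity")
class WeightSparsifier:
    """Placeholder for weight sparsity (e.g. magnitude pruning)."""
    def apply(self, model: nn.Module) -> nn.Module:
        return model

@register_mode("attention_sparsity")
class AttentionSparsifier:
    """Placeholder for attention sparsity (e.g. skip softmax)."""
    def apply(self, model: nn.Module) -> nn.Module:
        return model

def sparsify(model: nn.Module, mode: str) -> nn.Module:
    """Single shared entry point: sparsify(model, mode=...)."""
    if mode not in SPARSITY_MODES:
        raise ValueError(f"Unknown sparsity mode: {mode!r}")
    return SPARSITY_MODES[mode]().apply(model)
```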
Resolved review threads (outdated):
- tests/examples/llm_sparsity/attention_sparsity/test_attention_sparsity.py
- tests/gpu/torch/sparsity/attention_sparsity/test_attention_sparsity_gpu.py
- tests/gpu/torch/sparsity/attention_sparsity/test_integration_gpu.py
- tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_config.py
jy-yuan left a comment:
Great work on the overall architecture!
Resolved review thread (outdated): modelopt/torch/sparsity/attention_sparsity/plugins/huggingface.py
```python
total_blocks = (
    num_block_rows * (num_block_rows + 1) // 2  # Causal: N(N+1)/2
)
```
Does that mean rows == columns? That would imply we only handle causal self-attention, not cross-attention?
Yes, it's for causal attention in prefill.
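For readers following along, a small sketch of the counting logic being discussed; names are illustrative, not the PR's code:

```python
def count_computed_blocks(num_block_rows: int, num_block_cols: int, causal: bool) -> int:
    """Count attention score blocks that must be computed.

    Causal self-attention has a square block grid and only needs the lower
    triangle plus the diagonal: N(N+1)/2 blocks. Cross-attention has no
    causal structure, so all rows * cols blocks are computed.
    """
    if causal:
        assert num_block_rows == num_block_cols, "causal counting assumes self-attention"
        return num_block_rows * (num_block_rows + 1) // 2
    return num_block_rows * num_block_cols

# A 4x4 block grid needs 4 * 5 // 2 = 10 blocks under a causal mask, vs 16 dense.
```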
| "--backend", | ||
| type=str, | ||
| default="pytorch", | ||
| choices=["pytorch", "triton"], |
Is "triton" a TODO?
Yes.
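Until that lands, one hedged way to make the unimplemented choice fail fast (illustrative, not the PR's actual handling; `args` is the parsed namespace from the snippet above):

```python
if args.backend == "triton":
    # Placeholder until the Triton kernel lands in a follow-up PR.
    raise NotImplementedError("Only --backend pytorch is implemented so far.")
```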
```python
method = getattr(module, "_method", "unknown")
threshold = getattr(module, "_threshold", "N/A")
```
Does SparseAttentionModule actually have _method or _threshold attributes?
print_sparse_attention_summary isn't used in this PR, so I've removed it. It will be reintroduced in the next PR.
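For context, a hypothetical sketch of what such a summary helper might look like once reintroduced; the name-based type check and attribute defaults are assumptions, not the actual implementation:

```python
import torch.nn as nn

def print_sparse_attention_summary(model: nn.Module) -> None:
    """Hypothetical sketch: report per-module sparse attention settings.

    The getattr defaults make this safe even for modules converted
    without a threshold-based method.
    """
    for name, module in model.named_modules():
        # Name-based check keeps the sketch import-free; real code would
        # use isinstance(module, SparseAttentionModule).
        if type(module).__name__ == "SparseAttentionModule":
            method = getattr(module, "_method", "unknown")
            threshold = getattr(module, "_threshold", "N/A")
            print(f"{name}: method={method}, threshold={threshold}")
```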
```python
def restore_sparse_attention_state(model: nn.Module, state_dict: dict[str, Any]):
    """Restore sparse attention state from state dict.

    Args:
        model: Model with sparse attention modules
        state_dict: Saved state dictionary
    """
    for name, module in model.named_modules():
        if isinstance(module, SparseAttentionModule):
            module_name = get_unwrapped_name(name, model)
            if module_name in state_dict:
                module_state = state_dict[module_name]

                # Restore method and config
                if "method" in module_state:
                    module._method = module_state["method"]
                if "method_config" in module_state:
                    # Restore config attributes
                    for key, val in module_state["method_config"].items():
                        setattr(module, f"_{key}", val)

                # Re-setup with restored config
                module._setup()
```
Do we need to add a test for this?
test_restore_sparse_attention_model covers this function.
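For reference, a round-trip restore test typically takes roughly this shape; the fixture and save-side helper names here are hypothetical stand-ins, not the PR's actual test code:

```python
def test_restore_sparse_attention_state_roundtrip(tiny_sparse_model):
    # Hypothetical fixture (tiny_sparse_model) and assumed save-side helper;
    # only restore_sparse_attention_state is from this PR.
    state = save_sparse_attention_state(tiny_sparse_model)

    sparse_modules = [
        m for m in tiny_sparse_model.modules()
        if type(m).__name__ == "SparseAttentionModule"
    ]
    original_method = sparse_modules[0]._method

    # Perturb the module, then restore and check the state came back.
    sparse_modules[0]._method = "perturbed"
    restore_sparse_attention_state(tiny_sparse_model, state)
    assert sparse_modules[0]._method == original_method
```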
Force-pushed cd6fce2 to 0ca4d20
@kevalmorabia97 I've addressed the review suggestions. Could you please review and approve the PR so I can move forward with the subsequent PRs? Thanks.
Signed-off-by: Kai Xu <[email protected]>
Force-pushed 0ca4d20 to 02182f8
kevalmorabia97 left a comment:
Only reviewed the high-level structure. I'd suggest someone else take a look at the core sparsity logic, as I'm not familiar with it.
Signed-off-by: Kai Xu <[email protected]>
Force-pushed 02182f8 to 8acf333
Looks great! Should we have a simpler high-level usage that aligns with …
What does this PR do?
Type of change: New feature
Overview:
This PR adds sparse attention support to ModelOpt, applying attention sparsity through a skip-softmax method to enable inference speedups for LLMs.
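As a rough illustration of the skip-softmax idea (a simplified reading, not the PR's kernel code): logits far enough below their row max contribute negligibly after softmax, so they can be dropped before normalization.

```python
import torch

def skip_softmax(scores: torch.Tensor, threshold: float) -> torch.Tensor:
    """Simplified, per-element sketch of skip softmax on attention logits.

    Entries more than `threshold` below their row max contribute at most
    exp(-threshold) each to the unnormalized softmax, so they are zeroed
    before normalization. The real method works blockwise for efficiency.
    """
    row_max = scores.max(dim=-1, keepdim=True).values
    keep = scores >= row_max - threshold
    exp_scores = torch.exp(scores - row_max) * keep
    return exp_scores / exp_scores.sum(dim=-1, keepdim=True)
```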
Key Features:
Design doc
Usage
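A hedged sketch of the intended call shape, inferred from the config name in the benchmark command below; the import path and entry-point name are assumptions, not the confirmed API (see the review thread above about unifying this under a top-level mts.sparsify):

```python
from transformers import AutoModelForCausalLM

# Assumed import path and entry point, for illustration only.
from modelopt.torch.sparsity.attention_sparsity import SKIP_SOFTMAX_DEFAULT, sparsify_attention

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-4B")
model = sparsify_attention(model, config=SKIP_SOFTMAX_DEFAULT)  # hypothetical signature
```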
Testing
Unit Test
ALL PASSED.
Accuracy
Benchmark: MMLU
Model: Qwen/Qwen3-4B
Cmd: `python mmlu.py --model_name causal --model_path Qwen/Qwen3-4B --sparse_cfg SKIP_SOFTMAX_DEFAULT`